from data_juicer.ops.base_op import OPERATORS, Mapper
from data_juicer.ops.op_fusion import LOADED_IMAGES
from data_juicer.utils.lazy_loader import LazyLoader
from data_juicer.utils.mm_utils import load_image
from data_juicer.utils.model_utils import get_model, prepare_model

# Registry name under which this operator is exposed.
OP_NAME = 'mllm_mapper'

# Lazy imports: torch/transformers are only actually loaded on first use.
torch = LazyLoader('torch', 'torch')
transformers = LazyLoader('transformers', 'transformers')
# Pin torch to a single intra-op thread — presumably to avoid thread
# oversubscription when ops run in multiple worker processes (TODO confirm).
torch.set_num_threads(1)


@LOADED_IMAGES.register_module(OP_NAME)
@OPERATORS.register_module(OP_NAME)
class MllmMapper(Mapper):
    """Mapper to use MLLMs for visual question answering tasks.

    For each image of a sample, the sample's text is used as the question
    and the model answer is collected; the text field is replaced by the
    list of answers (one per unique image).

    Recommended model list: [
        llava-hf/llava-v1.6-vicuna-7b-hf,
        Qwen/Qwen2-VL-7B-Instruct,
    ]
    """
    _accelerator = 'cuda'

    def __init__(self,
                 hf_model: str = 'llava-hf/llava-v1.6-vicuna-7b-hf',
                 max_new_tokens=256,
                 temperature=0.2,
                 top_p=None,
                 num_beams=1,
                 *args,
                 **kwargs):
        """
        Initialization method.
        :param hf_model: hugginface model id.
        :param max_new_tokens: the maximum number of new tokens
            generated by the model.
        :param temperature: used to control the randomness of \
            generated text. The higher the temperature, the more \
                random and creative the generated text will be.
        :param top_p: randomly select the next word from the group \
            of words whose cumulative probability reaches p.
        :param num_beams: the larger the beam search size, the higher \
            the quality of the generated text.
        :param args: extra args
        :param kwargs: extra args
        """
        kwargs.setdefault('mem_required', '32GB')
        kwargs.setdefault('num_proc', 1)
        super().__init__(*args, **kwargs)

        self.hf_model = hf_model
        # Model is prepared lazily/per-rank; only the key is stored here.
        self.model_key = prepare_model(model_type='huggingface',
                                       pretrained_model_name_or_path=hf_model)
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        self.top_p = top_p
        self.num_beams = num_beams

    def process_single(self, sample=None, rank=None):
        """Answer the sample's text question for each of its images.

        :param sample: a single sample dict; returned unchanged when it
            carries no images (or is falsy).
        :param rank: rank used by ``get_model`` to pick the replica/device.
        :return: the sample, with its text field replaced by a list of
            model answers, one per unique image.
        """
        # there is no image in this sample (also guards a missing sample,
        # which previously raised TypeError on the membership test)
        if not sample or self.image_key not in sample \
                or not sample[self.image_key]:
            return sample

        # load images, skipping keys already loaded to avoid duplicate I/O
        loaded_image_keys = sample[self.image_key]
        images = {}
        for loaded_image_key in loaded_image_keys:
            if loaded_image_key not in images:
                images[loaded_image_key] = load_image(loaded_image_key)

        model, processor = get_model(model_key=self.model_key,
                                     rank=rank,
                                     use_cuda=self.use_cuda())

        # Build the chat prompt once. The question text is read here,
        # before the text field is overwritten below.
        conversation = [
            {
                'role':
                'user',
                'content': [
                    {
                        'type': 'text',
                        'text': sample[self.text_key]
                    },
                    {
                        'type': 'image'
                    },
                ],
            },
        ]
        prompt = processor.apply_chat_template(conversation,
                                               add_generation_prompt=True)

        # `temperature`/`top_p` only take effect when sampling is enabled;
        # without `do_sample=True`, transformers ignores them and decodes
        # greedily. Enable sampling exactly when they are meaningfully set.
        do_sample = bool(self.temperature) or self.top_p is not None

        sample[self.text_key] = []

        for image_key in images:
            inputs = processor(images=images[image_key],
                               text=prompt,
                               return_tensors='pt').to(model.device)

            response = model.generate(**inputs,
                                      do_sample=do_sample,
                                      max_new_tokens=self.max_new_tokens,
                                      temperature=self.temperature,
                                      top_p=self.top_p,
                                      num_beams=self.num_beams)

            # NOTE(review): the full sequence is decoded, so the output
            # includes the prompt text as well as the answer — confirm this
            # is intended before trimming to the newly generated tokens.
            output = processor.decode(response.cpu()[0],
                                      skip_special_tokens=True)

            sample[self.text_key].append(output)

        return sample
